{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "23083b99",
   "metadata": {},
   "source": [
    "# Torch jit trace test for quantizer"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ebfede57",
   "metadata": {},
   "source": [
    "### 1. How to capture whole graph in quantizer ?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88e243ca",
   "metadata": {},
   "source": [
    "Quantizer capture the whole graph by using tracing in PyTorch. The tracing is an export method. It runs a model with example inputs, recording the operations performed on all the tensors. Quantizer use two different Pytorch API to get tracing graph. One is \"_get_trace_graph\" which is used to get graph from model without control flow. This internal API was designed earlier than \"torch.jit.trace\" for onnx exporting. The other one is \"torch.jit.trace\", it is used to get graph from model with control flow. Of course, the control flow part is scripted using \"@script_if_tracing\".Typically, this only requires a small refactor of the forward fuction to separate the control flow parts that need to be compiled.That does not means we fully support the torch script.For quantizer requirments, we should use tracing for the majority of logic, and use scripting only when necessary."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32519a72",
   "metadata": {},
   "source": [
    "### 2. The problems related with tracing you should pay attention to "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6edd0d97",
   "metadata": {},
   "source": [
    "##### 1. Dynamic Control flow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4bc7e679",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):\n",
      "  %1 : Float(3, strides=[1], requires_grad=0, device=cpu) = aten::relu(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1457:0\n",
      "  return (%1)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/ipykernel_launcher.py:3: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "def f(x):\n",
    "    return torch.nn.functional.relu(x) if x.sum() > 0 else torch.nn.functional.relu6(x)\n",
    "\n",
    "traced_script = torch.jit.trace(f, torch.randn(3))\n",
    "print(traced_script.inlined_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d60ca16a",
   "metadata": {},
   "source": [
    "In this example, the trace only keeps one branch of control flow which is depend on the concret inputs. If we truely want to preserve the control flow in the function, we can use the \"@script_if_tracing\"."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "3db65044",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph(%x : Float(3, strides=[1], requires_grad=0, device=cpu)):\n",
      "  %1 : Function = prim::Constant[name=\"f\"]()\n",
      "  %3 : int = prim::Constant[value=0]() # /tmp/ipykernel_28127/3830558629.py:4:52\n",
      "  %4 : NoneType = prim::Constant()\n",
      "  %5 : Tensor = aten::sum(%x, %4) # /tmp/ipykernel_28127/3830558629.py:4:42\n",
      "  %6 : Tensor = aten::gt(%5, %3) # /tmp/ipykernel_28127/3830558629.py:4:42\n",
      "  %7 : bool = aten::Bool(%6) # /tmp/ipykernel_28127/3830558629.py:4:42\n",
      "  %8 : Tensor = prim::If(%7) # /tmp/ipykernel_28127/3830558629.py:4:11\n",
      "    block0():\n",
      "      %result.6 : Tensor = aten::relu(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1457:17\n",
      "      -> (%result.6)\n",
      "    block1():\n",
      "      %result.3 : Tensor = aten::relu6(%x) # /proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/nn/functional.py:1534:17\n",
      "      -> (%result.3)\n",
      "  return (%8)\n",
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "@torch.jit.script_if_tracing\n",
    "def f(x):\n",
    "    return torch.nn.functional.relu(x) if x.sum() > 0 else torch.nn.functional.relu6(x)\n",
    "\n",
    "traced_script = torch.jit.trace(f, torch.randn(3))\n",
    "print(traced_script.inlined_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f141ee9",
   "metadata": {},
   "source": [
    "##### 2. Freeze variables as constants"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "69aed12a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def f(x: Tensor) -> Tensor:\n",
      "  _0 = torch.arange(1, dtype=None, layout=0, device=torch.device(\"cpu\"), pin_memory=False)\n",
      "  return _0\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/ipykernel_launcher.py:3: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.rand(1)\n",
    "y = torch.rand(2)\n",
    "def f(x): return torch.arange(len(x))\n",
    "traced_script = torch.jit.trace(f, x)\n",
    "print(traced_script.code)\n",
    "traced_script(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70a7c1de",
   "metadata": {},
   "source": [
    "Intermediate computation results of a non-Tensor type (in this case, an int type) may be frozen as constants, using the value observed during tracing. This causes the trace to not generalize. we should use symbolic shapes instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "59b4e56e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def f(x: Tensor) -> Tensor:\n",
      "  _0 = ops.prim.NumToTensor(torch.size(x, 0))\n",
      "  _1 = torch.arange(annotate(number, _0), dtype=None, layout=0, device=torch.device(\"cpu\"), pin_memory=False)\n",
      "  return _1\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "tensor([0, 1])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "x = torch.rand(1)\n",
    "y = torch.rand(2)\n",
    "def f(x): return torch.arange(x.size(0))\n",
    "traced_script = torch.jit.trace(f, x)\n",
    "print(traced_script.code)\n",
    "traced_script(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e2935038",
   "metadata": {},
   "source": [
    "##### 3. Freeze device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "af018ab3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "def f(x: Tensor) -> Tensor:\n",
      "  return torch.to(x, torch.device(\"cpu\"), 6)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/ipykernel_launcher.py:3: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "def f(x):\n",
    "    return torch.as_tensor(x, device=x.device)\n",
    "traced_script = torch.jit.trace(f, torch.randn(2))\n",
    "print(traced_script.code)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcca37c0",
   "metadata": {},
   "source": [
    "The device attribute of input will be frozen during tracing.The trace script may not generalize to inputs on a different device. Such generalization is almost never needed, because deployment usually has a target device."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b333f836",
   "metadata": {},
   "source": [
    "##### 4. Input/output format"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ece0c753",
   "metadata": {},
   "source": [
    "Model's inputs/outputs have to be \"Union[Tensor, Tuple[Tensor]]\" to be traceable.The format requirement only applies to the outer-most model, so it's very easy to address. If the model uses richer formats such \"Dict[str, tensor]\", just create a simple wrapper around it that converts to/from Tuple[Tensor]. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "94dbea61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': tensor(1.4142), 'b': tensor(9.)}\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_28127/3918274742.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     12\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mRichFormatModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 14\u001b[0;31m \u001b[0mtrace_script\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mRichFormatModel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     15\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace_script\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minlined_graph\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/jit/_trace.py\u001b[0m in \u001b[0;36mtrace\u001b[0;34m(func, example_inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)\u001b[0m\n\u001b[1;32m    757\u001b[0m             \u001b[0mstrict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    758\u001b[0m             \u001b[0m_force_outplace\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 759\u001b[0;31m             \u001b[0m_module_class\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    760\u001b[0m         )\n\u001b[1;32m    761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/proj/rdi/staff/wluo/tools/anaconda3/envs/torch1.12/lib/python3.7/site-packages/torch/jit/_trace.py\u001b[0m in \u001b[0;36mtrace_module\u001b[0;34m(mod, inputs, optimize, check_trace, check_inputs, check_tolerance, strict, _force_outplace, _module_class, _compilation_unit)\u001b[0m\n\u001b[1;32m    972\u001b[0m                 \u001b[0mstrict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    973\u001b[0m                 \u001b[0m_force_outplace\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 974\u001b[0;31m                 \u001b[0margument_names\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    975\u001b[0m             )\n\u001b[1;32m    976\u001b[0m             \u001b[0mcheck_trace_method\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_c\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmethod_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mRuntimeError\u001b[0m: Encountering a dict at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior."
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from typing import Dict\n",
    "class RichFormatModel(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    def forward(self, x:Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:\n",
    "        y = {}\n",
    "        y[\"a\"] = torch.sqrt(x[\"a\"])\n",
    "        y[\"b\"] = torch.square(x[\"b\"])\n",
    "        return y\n",
    "input = {\"a\": torch.tensor(2.0), \"b\": torch.tensor(3.0)}\n",
    "output = RichFormatModel()(input)\n",
    "print(output)\n",
    "trace_script = torch.jit.trace(RichFormatModel(), input)\n",
    "print(trace_script.inlined_graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5af1497",
   "metadata": {},
   "source": [
    "We can add wrappers to manually transform the input into a flattened input, and refactor the flattened output into RichFormatModel 's rich format output, suitable for tracing and downstream tasks. Hopefully we can automate the format conversion in the near future."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "75267078",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'a': tensor(1.4142), 'b': tensor(9.)}\n"
     ]
    }
   ],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "class RichFormatWrapper(torch.nn.Module):\n",
    "    def __init__(self, trace_model):\n",
    "        super().__init__()\n",
    "        self.trace_model = trace_model\n",
    "        \n",
    "    def forward(self, x:Dict[str, torch.tensor]) -> Dict[str, torch.tensor]:\n",
    "        flatten_x = x[\"a\"], x[\"b\"]\n",
    "        flatten_outputs = self.trace_model(*flatten_x)\n",
    "        return {\"a\": flatten_outputs[0], \"b\": flatten_outputs[1]}\n",
    "    \n",
    "class TraceWrapper(torch.nn.Module):\n",
    "    def __init__(self, origin_model):\n",
    "        super().__init__()\n",
    "        self.origin_model = origin_model\n",
    "    \n",
    "    def forward(self, *x: Tuple[torch.tensor]) -> Tuple[torch.tensor]:\n",
    "        dict_inputs = {\"a\": x[0], \"b\": x[1]}\n",
    "        dict_outputs = self.origin_model(dict_inputs)\n",
    "        flatten_outputs = dict_outputs[\"a\"], dict_outputs[\"b\"]\n",
    "        return flatten_outputs\n",
    "    \n",
    "trace_model = TraceWrapper(RichFormatModel())\n",
    "flatten_inputs = input[\"a\"], input[\"b\"]\n",
    "trace_script = torch.jit.trace(trace_model, flatten_inputs)\n",
    "new_model = RichFormatWrapper(trace_script)\n",
    "outputs = new_model(input)\n",
    "print(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e9dee92",
   "metadata": {},
   "source": [
    "### 3. How to pass jit test ?"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59ca7a33",
   "metadata": {},
   "source": [
    "step 1:  Do torch.jit.trace test, refer to https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch+jit+trace#torch.jit.trace. If you encounter the error \"TracingCheckError: Tracing failed sanity checks!\". This means that if your model trace twice with the same inputs, it will get a different graph. You can set \"check_trace=False\" to walk around it."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a00def71",
   "metadata": {},
   "source": [
    "step 2: Use the trace script for evaluation testing to ensure that the trace script behaves correctly. If you have any problems with the evaluation test. The reason is that the traced script may depend on the traced input, you can try to trace with real data instead of dummy data. If the problem persists, check and modify your model to be independent of specific inputs."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
